/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.test.checkpointing;

import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.restartstrategy.RestartStrategies;
import org.apache.flink.api.common.state.CheckpointListener;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.MemorySize;
import org.apache.flink.configuration.TaskManagerOptions;
import org.apache.flink.contrib.streaming.state.RocksDBStateBackend;
import org.apache.flink.runtime.state.AbstractStateBackend;
import org.apache.flink.runtime.state.filesystem.FsStateBackend;
import org.apache.flink.runtime.state.memory.MemoryStateBackend;
import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration;
import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
import org.apache.flink.test.util.MiniClusterWithClientResource;
import org.apache.flink.util.TestLogger;

import org.junit.ClassRule;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;

/**
 * A simple test that runs a streaming topology with checkpointing enabled.
 *
 * <p>The test triggers a failure after a while and verifies that, after completion, the state
 * reflects the "exactly once" semantics.
 *
 * <p>It is designed to check partitioned states.
 */
@SuppressWarnings("serial")
public class KeyedStateCheckpointingITCase extends TestLogger {

    /** Size cap (10 MiB) for the MemoryStateBackend used by the memory/RocksDB test variants. */
    protected static final int MAX_MEM_STATE_SIZE = 10 * 1024 * 1024;

    /** Total number of elements emitted across BOTH sources combined (each source emits half). */
    protected static final int NUM_STRINGS = 10_000;

    /** Number of distinct keys; every emitted element is reduced to (index % NUM_KEYS). */
    protected static final int NUM_KEYS = 40;

    protected static final int NUM_TASK_MANAGERS = 2;
    protected static final int NUM_TASK_SLOTS = 2;
    protected static final int PARALLELISM = NUM_TASK_MANAGERS * NUM_TASK_SLOTS;

    // ------------------------------------------------------------------------

    // Shared mini cluster for all tests in this class (hence @ClassRule, not @Rule).
    @ClassRule
    public static final MiniClusterWithClientResource MINI_CLUSTER_RESOURCE =
            new MiniClusterWithClientResource(
                    new MiniClusterResourceConfiguration.Builder()
                            .setConfiguration(getConfiguration())
                            .setNumberTaskManagers(NUM_TASK_MANAGERS)
                            .setNumberSlotsPerTaskManager(NUM_TASK_SLOTS)
                            .build());

    /** Builds the cluster configuration: managed memory is needed by the RocksDB variants. */
    private static Configuration getConfiguration() {
        Configuration config = new Configuration();
        config.set(TaskManagerOptions.MANAGED_MEMORY_SIZE, MemorySize.parse("12m"));
        return config;
    }

    // ------------------------------------------------------------------------

    // Per-test scratch directory for FsStateBackend checkpoints and RocksDB local storage.
    @Rule public final TemporaryFolder tmpFolder = new TemporaryFolder();

    @Test
    public void testWithMemoryBackendSync() throws Exception {
        MemoryStateBackend syncMemBackend = new MemoryStateBackend(MAX_MEM_STATE_SIZE, false);
        testProgramWithBackend(syncMemBackend);
    }

    @Test
    public void testWithMemoryBackendAsync() throws Exception {
        MemoryStateBackend asyncMemBackend = new MemoryStateBackend(MAX_MEM_STATE_SIZE, true);
        testProgramWithBackend(asyncMemBackend);
    }

    @Test
    public void testWithFsBackendSync() throws Exception {
        FsStateBackend syncFsBackend =
                new FsStateBackend(tmpFolder.newFolder().toURI().toString(), false);
        testProgramWithBackend(syncFsBackend);
    }

    @Test
    public void testWithFsBackendAsync() throws Exception {
        FsStateBackend asyncFsBackend =
                new FsStateBackend(tmpFolder.newFolder().toURI().toString(), true);
        testProgramWithBackend(asyncFsBackend);
    }

    @Test
    public void testWithRocksDbBackendFull() throws Exception {
        // 'false' selects full (non-incremental) snapshots.
        RocksDBStateBackend fullRocksDbBackend =
                new RocksDBStateBackend(new MemoryStateBackend(MAX_MEM_STATE_SIZE), false);
        fullRocksDbBackend.setDbStoragePath(tmpFolder.newFolder().getAbsolutePath());

        testProgramWithBackend(fullRocksDbBackend);
    }

    @Test
    public void testWithRocksDbBackendIncremental() throws Exception {
        // 'true' selects incremental snapshots.
        RocksDBStateBackend incRocksDbBackend =
                new RocksDBStateBackend(new MemoryStateBackend(MAX_MEM_STATE_SIZE), true);
        incRocksDbBackend.setDbStoragePath(tmpFolder.newFolder().getAbsolutePath());

        testProgramWithBackend(incRocksDbBackend);
    }

    // ------------------------------------------------------------------------

    /**
     * Runs the shared test topology against the given state backend.
     *
     * <p>Topology: two parallel int sources (each emitting NUM_STRINGS / 2 elements, reduced
     * modulo NUM_KEYS) are unioned, keyed by value, summed per key in {@link
     * OnceFailingPartitionedSum} (which fails exactly once mid-run to force a recovery), re-keyed,
     * and counted per key in {@link CounterSink}.
     *
     * <p>After the job finishes, the static result maps are checked: with exactly-once semantics
     * each key must have been seen exactly NUM_STRINGS / NUM_KEYS times, so the per-key sum is
     * key * (NUM_STRINGS / NUM_KEYS).
     *
     * @param stateBackend the state backend under test
     */
    protected void testProgramWithBackend(AbstractStateBackend stateBackend) throws Exception {
        // Each source's element count must be a multiple of NUM_KEYS so every key
        // receives the same number of elements; otherwise the expected-value math below breaks.
        assertEquals("Broken test setup", 0, (NUM_STRINGS / 2) % NUM_KEYS);

        final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(PARALLELISM);
        env.enableCheckpointing(500);
        // Unlimited immediate restarts: the induced failure must be recoverable.
        env.setRestartStrategy(RestartStrategies.fixedDelayRestart(Integer.MAX_VALUE, 0L));

        env.setStateBackend(stateBackend);

        // compute when (randomly) the failure should happen: somewhere between 60% and 80%
        // of the per-subtask element count, so a checkpoint has a chance to complete first
        final int failurePosMin = (int) (0.6 * NUM_STRINGS / PARALLELISM);
        final int failurePosMax = (int) (0.8 * NUM_STRINGS / PARALLELISM);
        final int failurePos =
                (new Random().nextInt(failurePosMax - failurePosMin) + failurePosMin);

        // sources hard-block past NUM_STRINGS / 4 elements until a checkpoint has completed,
        // guaranteeing the induced failure happens after at least one successful checkpoint
        final DataStream<Integer> stream1 =
                env.addSource(new IntGeneratingSourceFunction(NUM_STRINGS / 2, NUM_STRINGS / 4));

        final DataStream<Integer> stream2 =
                env.addSource(new IntGeneratingSourceFunction(NUM_STRINGS / 2, NUM_STRINGS / 4));

        stream1.union(stream2)
                .keyBy(new IdentityKeySelector<Integer>())
                .map(new OnceFailingPartitionedSum(failurePos))
                .keyBy(0)
                .addSink(new CounterSink());

        env.execute();

        // verify that we counted exactly right
        assertEquals(NUM_KEYS, CounterSink.ALL_COUNTS.size());
        assertEquals(NUM_KEYS, OnceFailingPartitionedSum.ALL_SUMS.size());

        // each key k is emitted NUM_STRINGS / NUM_KEYS times, so its sum is k * that count
        for (Entry<Integer, Long> sum : OnceFailingPartitionedSum.ALL_SUMS.entrySet()) {
            assertEquals((long) sum.getKey() * NUM_STRINGS / NUM_KEYS, sum.getValue().longValue());
        }
        for (long count : CounterSink.ALL_COUNTS.values()) {
            assertEquals(NUM_STRINGS / NUM_KEYS, count);
        }
    }

    // --------------------------------------------------------------------------------------------
    // Custom Functions
    // --------------------------------------------------------------------------------------------

    /**
     * A source that generates a sequence of integers and throttles down until a checkpoint has
     * happened.
     *
     * <p>Each parallel subtask emits a strided slice of [0, numElements): starting at its subtask
     * index and stepping by the parallelism, each element reduced modulo NUM_KEYS. Until a
     * checkpoint completes (or state is restored after a failure), emission is throttled; once
     * 'checkpointLatestAt' is reached without a checkpoint, the source blocks hard so the
     * induced downstream failure can never precede the first successful checkpoint.
     */
    private static class IntGeneratingSourceFunction extends RichParallelSourceFunction<Integer>
            implements ListCheckpointed<Integer>, CheckpointListener {

        private final int numElements;
        private final int checkpointLatestAt;

        // Last element index emitted by this subtask; -1 means "nothing emitted yet".
        // Snapshotted/restored so emission resumes exactly where it left off.
        private int lastEmitted = -1;

        // Set under 'synchronized (this)' in notifyCheckpointComplete(), and in restoreState().
        // Not volatile: the unsynchronized read in run() is only an optimization — the hard-block
        // path re-checks under the monitor, and a stale 'false' merely causes extra throttling.
        private boolean checkpointHappened;

        private volatile boolean isRunning = true;

        IntGeneratingSourceFunction(int numElements, int checkpointLatestAt) {
            this.numElements = numElements;
            this.checkpointLatestAt = checkpointLatestAt;
        }

        @Override
        public void run(SourceContext<Integer> ctx) throws Exception {
            final Object lockingObject = ctx.getCheckpointLock();
            final int step = getRuntimeContext().getNumberOfParallelSubtasks();

            // Fresh start: begin at this subtask's index. After restore: continue one stride
            // past the last element that made it into the checkpoint.
            int nextElement =
                    lastEmitted >= 0
                            ? lastEmitted + step
                            : getRuntimeContext().getIndexOfThisSubtask();

            while (isRunning && nextElement < numElements) {

                // throttle / block if we are still waiting for the checkpoint
                if (!checkpointHappened) {
                    if (nextElement < checkpointLatestAt) {
                        // only throttle
                        Thread.sleep(1);
                    } else {
                        // hard block
                        synchronized (this) {
                            while (!checkpointHappened) {
                                // released by notifyAll() in notifyCheckpointComplete()
                                this.wait();
                            }
                        }
                    }
                }

                // emit and record progress atomically w.r.t. checkpoints, so a snapshot never
                // captures a 'lastEmitted' that is out of sync with what was collected
                //noinspection SynchronizationOnLocalVariableOrMethodParameter
                synchronized (lockingObject) {
                    ctx.collect(nextElement % NUM_KEYS);
                    lastEmitted = nextElement;
                }

                nextElement += step;
            }
        }

        @Override
        public void cancel() {
            isRunning = false;
        }

        @Override
        public List<Integer> snapshotState(long checkpointId, long timestamp) throws Exception {
            return Collections.singletonList(lastEmitted);
        }

        @Override
        public void restoreState(List<Integer> state) throws Exception {
            assertEquals("Test failed due to unexpected recovered state size", 1, state.size());
            lastEmitted = state.get(0);
            // Being restored proves a checkpoint completed, so emission no longer needs to wait.
            checkpointHappened = true;
        }

        @Override
        public void notifyCheckpointComplete(long checkpointId) throws Exception {
            synchronized (this) {
                checkpointHappened = true;
                // wake up run() if it is hard-blocked in wait()
                this.notifyAll();
            }
        }

        @Override
        public void notifyCheckpointAborted(long checkpointId) {}
    }

    /**
     * Keyed mapper that maintains a running per-key sum in {@link ValueState} and fails exactly
     * once (per execution attempt chain) after processing 'failurePos' elements, to exercise
     * recovery from the last checkpoint. Results are mirrored into the static ALL_SUMS map for
     * verification after the job finishes.
     */
    private static class OnceFailingPartitionedSum
            extends RichMapFunction<Integer, Tuple2<Integer, Long>>
            implements ListCheckpointed<Integer> {

        // Shared across all parallel instances and execution attempts; per-key entries are
        // simply overwritten on reprocessing, so the final value reflects the completed run.
        private static final Map<Integer, Long> ALL_SUMS = new ConcurrentHashMap<>();

        private final int failurePos;
        private int count;

        // Cleared after throwing, and on restore — so the failure is induced only once.
        private boolean shouldFail = true;

        private transient ValueState<Long> sum;

        OnceFailingPartitionedSum(int failurePos) {
            this.failurePos = failurePos;
        }

        @Override
        public void open(Configuration parameters) throws IOException {
            sum = getRuntimeContext().getState(new ValueStateDescriptor<>("my_state", Long.class));
        }

        @Override
        public Tuple2<Integer, Long> map(Integer value) throws Exception {
            if (shouldFail && count++ >= failurePos) {
                shouldFail = false;
                throw new Exception("Test Failure");
            }

            Long oldSum = sum.value();
            long currentSum = (oldSum == null ? 0L : oldSum) + value;

            sum.update(currentSum);
            ALL_SUMS.put(value, currentSum);
            return new Tuple2<>(value, currentSum);
        }

        @Override
        public List<Integer> snapshotState(long checkpointId, long timestamp) throws Exception {
            return Collections.singletonList(count);
        }

        @Override
        public void restoreState(List<Integer> state) throws Exception {
            assertEquals("Test failed due to unexpected recovered state size", 1, state.size());
            count = state.get(0);
            // A restore means the failure already happened — do not fail again.
            shouldFail = false;
        }

        @Override
        public void close() throws Exception {
            // If no instance ever threw, the failure position was never reached and the
            // test did not actually exercise recovery.
            if (shouldFail) {
                fail("Test ineffective: Function cleanly finished without ever failing.");
            }
        }
    }

    /**
     * Keyed sink that counts elements per key in two parallel {@link ValueState}s — one holding a
     * non-Serializable type, one a plain Long — and asserts on every element that the two agree.
     * The NonSerializableLong state verifies that the backend does not rely on Java
     * serialization for state values. Final counts are mirrored into ALL_COUNTS for verification.
     */
    private static class CounterSink extends RichSinkFunction<Tuple2<Integer, Long>> {

        // Shared across parallel instances and attempts; see ALL_SUMS for the overwrite rationale.
        private static final Map<Integer, Long> ALL_COUNTS = new ConcurrentHashMap<>();

        private transient ValueState<NonSerializableLong> aCounts;
        private transient ValueState<Long> bCounts;

        @Override
        public void open(Configuration parameters) throws IOException {
            aCounts =
                    getRuntimeContext()
                            .getState(new ValueStateDescriptor<>("a", NonSerializableLong.class));
            bCounts = getRuntimeContext().getState(new ValueStateDescriptor<>("b", Long.class));
        }

        @Override
        public void invoke(Tuple2<Integer, Long> value) throws Exception {
            final NonSerializableLong acRaw = aCounts.value();
            final Long bcRaw = bCounts.value();

            // null means "no state for this key yet"
            final long ac = acRaw == null ? 0L : acRaw.value;
            final long bc = bcRaw == null ? 0L : bcRaw;

            // both states must always have been restored/updated in lockstep
            assertEquals(ac, bc);

            long currentCount = ac + 1;
            aCounts.update(NonSerializableLong.of(currentCount));
            bCounts.update(currentCount);

            ALL_COUNTS.put(value.f0, currentCount);
        }
    }

    /** Key selector that uses the record itself as its key. */
    private static class IdentityKeySelector<T> implements KeySelector<T, T> {

        @Override
        public T getKey(T value) throws Exception {
            return value;
        }
    }

    // ------------------------------------------------------------------------
    //  data types
    // ------------------------------------------------------------------------

    /** Custom boxed long type that does not implement Serializable. */
    public static class NonSerializableLong {

        public long value;

        private NonSerializableLong(long value) {
            this.value = value;
        }

        /** Static factory; the constructor is private to force use of this method. */
        public static NonSerializableLong of(long value) {
            return new NonSerializableLong(value);
        }

        @Override
        public boolean equals(Object obj) {
            return this == obj
                    || obj != null
                            && obj.getClass() == getClass()
                            && ((NonSerializableLong) obj).value == this.value;
        }

        @Override
        public int hashCode() {
            // same folding as Long.hashCode(): xor of the two 32-bit halves
            return (int) (value ^ (value >>> 32));
        }
    }
}
